/**
 * Copyright (c) Huawei Technologies Co., Ltd. 2024. All rights reserved.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "level2/aclnn_ffn.h"
#include "level2/aclnn_ffn_v2.h"
#include "level2/aclnn_ffn_v3.h"

#include <dlfcn.h>
#include <new>
#include <cctype>
#include <cstring>

#include "aclnn_kernels/transdata.h"
#include "level0/ffn.h"
#include "aclnn_kernels/contiguous.h"
#include "aclnn_kernels/reshape.h"
#include "acl/acl.h"
#include "aclnn/aclnn_base.h"
#include "aclnn_kernels/common/op_error_check.h"
#include "opdev/common_types.h"
#include "opdev/data_type_utils.h"
#include "opdev/format_utils.h"
#include "opdev/op_dfx.h"
#include "opdev/op_executor.h"
#include "opdev/op_log.h"
#include "opdev/platform.h"
#include "opdev/shape_utils.h"
#include "opdev/tensor_view_utils.h"
#include "opdev/make_op_executor.h"

using namespace op;

#ifdef __cplusplus
extern "C" {
#endif
void NnopbaseSetUserHandle(void *executor, void *handle);
void *NnopbaseGetUserHandle(void *executor);

static const int64_t DIM_LIMIT_UPPER = 8;
static const int64_t DIM_LIMIT_LOWER = 2;
static const int64_t HIGH_PERFORMANCE = 1;
static const int64_t HIGH_PRECISION = 0;
static const int64_t DIM_NUM_ONE = 1;
static const int64_t DIM_NUM_TWO = 2;
static const int64_t DIM_NUM_THREE = 3;
static const int64_t JUDGE_EVEN = 2;
static const int64_t MAX_K = 65535;
using DtypeCheck = std::initializer_list<op::DataType>;


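// Activation types supported by FFN. The order matters: entries from GEGLU
// onward are GLU variants, and the code below relies on
// activationType >= ActiveType::GEGLU to detect the glu case.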
enum class ActiveType {
  FASTGELU = 0,
  RELU,
  SILU,
  GELU,
  GEGLU,
  SWIGLU,
  REGLU,
  INVALID_TYPE
};

struct ActiveMap {
  const char *activeName;
  ActiveType activeType;
};

constexpr ActiveMap g_activeMap[] = {
    {"fastgelu", ActiveType::FASTGELU}, {"relu", ActiveType::RELU},   {"silu", ActiveType::SILU},
    {"gelu", ActiveType::GELU},         {"geglu", ActiveType::GEGLU}, {"swiglu", ActiveType::SWIGLU},
    {"reglu", ActiveType::REGLU},
};

struct FFNParams {
  const aclTensor *x = nullptr;
  const aclTensor *weight1 = nullptr;
  const aclTensor *weight2 = nullptr;
  const aclIntArray *expertTokensArr = nullptr;
  const aclTensor *expertTokens = nullptr;
  const aclTensor *bias1 = nullptr;
  const aclTensor *bias2 = nullptr;
  const aclTensor *scale = nullptr;
  const aclTensor *offset = nullptr;
  const aclTensor *deqScale1 = nullptr;
  const aclTensor *deqScale2 = nullptr;
  const aclTensor *antiquantScale1 = nullptr;
  const aclTensor *antiquantScale2 = nullptr;
  const aclTensor *antiquantOffset1 = nullptr;
  const aclTensor *antiquantOffset2 = nullptr;
  ActiveType activationType = ActiveType::INVALID_TYPE;
  int64_t innerPrecise = HIGH_PRECISION;
  bool tokensIndexFlag = false;
  const aclTensor *y = nullptr;
};

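// Case-insensitively maps an activation name (e.g. "fastgelu") to its
// ActiveType; returns INVALID_TYPE when the name is not in g_activeMap.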
static ActiveType GetActiveType(const char *activeType)
{
  for (const ActiveMap &item : g_activeMap) {
    size_t len = strlen(item.activeName);
    bool isValidActiveType = strlen(activeType) == len;
    // use for loop instead of strncasecmp to avoid possible out-of-bounds problems
    if (!isValidActiveType) {
      continue;
    }
    for (size_t i = 0; i < len; i++) {
      if (tolower(static_cast<unsigned char>(activeType[i])) != item.activeName[i]) {
        isValidActiveType = false;
        break;
      }
    }
    if (isValidActiveType) {
      OP_LOGD("aclnnFFN activeType is %s.", activeType);
      return item.activeType;
    }
  }
  return ActiveType::INVALID_TYPE;
}

static FVector<int64_t> GetShape(const aclTensor *tensor)
{
  FVector<int64_t> shape;
  if (tensor->GetViewShape().GetDimNum() == 0U) {
    shape.push_back(1);
  } else {
    size_t dimNum = tensor->GetViewShape().GetDimNum();
    for (size_t idx = 0U; idx < dimNum; idx++) {
      int64_t tmpVal = tensor->GetViewShape().GetDim(idx);
      shape.push_back(tmpVal);
    }
  }
  return shape;
}

static aclnnStatus CheckNotNull(const FFNParams &ffnParams)
{
  CHECK_COND(ffnParams.x != nullptr, ACLNN_ERR_PARAM_NULLPTR, "x must not be nullptr.");
  CHECK_COND(ffnParams.weight1 != nullptr, ACLNN_ERR_PARAM_NULLPTR, "weight1 must not be nullptr.");
  CHECK_COND(ffnParams.weight2 != nullptr, ACLNN_ERR_PARAM_NULLPTR, "weight2 must not be nullptr.");
  CHECK_COND(ffnParams.y != nullptr, ACLNN_ERR_PARAM_NULLPTR, "y must not be nullptr.");
  return ACLNN_SUCCESS;
}

static inline bool CheckY(const FFNParams &ffnParams)
{
  auto xDimNum = ffnParams.x->GetViewShape().GetDimNum();
  auto yDimNum = ffnParams.y->GetViewShape().GetDimNum();
  // y's dimNum should be equal to x's dimNum
  if (yDimNum != xDimNum) {
    OP_LOGE(ACLNN_ERR_PARAM_INVALID, "y's dimNum should be equal to x's dimNum, x dimNum is %zu, y dimNum is %zu",
            xDimNum, yDimNum);
    return false;
  }

  return true;
}

static bool CheckDtypeValidForFFNOp(const aclTensor *x, const aclTensor *weight1, const aclTensor *weight2)
{
  CHECK_COND(x->GetDataType() == DataType::DT_FLOAT16, false,
             "x only supports dtype float16 without expert tokens, but found %s.",
             op::ToString(x->GetDataType()).GetString());
  CHECK_COND(weight1->GetDataType() == DataType::DT_FLOAT16, false,
             "weight1 only supports dtype float16 without expert tokens, but found %s.",
             op::ToString(weight1->GetDataType()).GetString());
  CHECK_COND(weight2->GetDataType() == DataType::DT_FLOAT16, false,
             "weight2 only supports dtype float16 without expert tokens, but found %s.",
             op::ToString(weight2->GetDataType()).GetString());
  return true;
}

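// Checks a bias tensor against its weight: with expertTokens the bias must be
// (E, n), where E is the expert count and n is the weight's last dim; without
// experts it must be (n,).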
bool CheckBiasDimNum(const aclTensor *weight, const aclTensor *expertTokens, const aclTensor *bias,
                     const char *biasName) {
  bool hasExperts = (expertTokens != nullptr);
  size_t weightSize = hasExperts ? DIM_NUM_THREE : DIM_NUM_TWO;
  // bias dim num should be 2 with expertTokens and 1 without
  size_t biasSize = hasExperts ? DIM_NUM_TWO : DIM_NUM_ONE;
  auto biasDimNum = bias->GetViewShape().GetDimNum();
  if (biasDimNum != biasSize) {
    OP_LOGE(ACLNN_ERR_PARAM_INVALID,
            "%s DimNum should be (2: has experts, 1: no experts), but %s DimNum is %zu",
            biasName, biasName, biasDimNum);
    return false;
  }
  size_t weightNDimIdx = weightSize - 1;
  int64_t weightNDimValue = weight->GetViewShape().GetDim(weightNDimIdx);
  size_t biasNDimIdx = biasSize - 1;
  int64_t biasNDimValue = bias->GetViewShape().GetDim(biasNDimIdx);
  if (biasNDimValue != weightNDimValue) {
    OP_LOGE(ACLNN_ERR_PARAM_INVALID,
            "%s last dim should be %ld, but %s last dim is %ld",
            biasName, weightNDimValue, biasName, biasNDimValue);
    return false;
  }
  if (hasExperts) {
    int64_t expertTokensLength = expertTokens->GetViewShape().GetDim(0);
    int64_t biasEDimValue = bias->GetViewShape().GetDim(0);
    if (biasEDimValue != expertTokensLength) {
      OP_LOGE(ACLNN_ERR_PARAM_INVALID,
              "%s first dim should be %ld, but %s first dim is %ld",
              biasName, expertTokensLength, biasName, biasEDimValue);
      return false;
    }
  }
  return true;
}

static bool CheckBias(const FFNParams &ffnParams)
{
  bool bias1CheckResult = true;
  bool bias2CheckResult = true;
  if (ffnParams.bias1 != nullptr) {
    bias1CheckResult = CheckBiasDimNum(ffnParams.weight1, ffnParams.expertTokens, ffnParams.bias1, "bias1");
  }
  if (ffnParams.bias2 != nullptr) {
    bias2CheckResult = CheckBiasDimNum(ffnParams.weight2, ffnParams.expertTokens, ffnParams.bias2, "bias2");
  }
  return bias1CheckResult && bias2CheckResult;
}

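// In GLU variants (geglu/swiglu/reglu) the first matmul output is split into
// two halves (value and gate), so weight1's n1 must be even and weight2's k2
// must equal n1 / 2.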
static bool GluShapeCheck(int64_t weight1NDimValue, int64_t weight2kDimValue)
{
  // mod by 2 to check whether it is an odd number
  if (weight1NDimValue % 2 != 0) {
    OP_LOGE(ACLNN_ERR_PARAM_INVALID,
            "weight1 NDimValue should be an even number in the glu case, but weight1 "
            "NDimValue is %ld",
            weight1NDimValue);
    return false;
  }
  // weight2 kDimValue should equal half of weight1 NDimValue in the glu case
  if (weight2kDimValue != weight1NDimValue / 2) {
    OP_LOGE(ACLNN_ERR_PARAM_INVALID,
            "weight2 kDimValue should equal half of weight1 NDimValue in the glu case, "
            "but weight2 kDimValue is %ld and weight1 NDimValue is %ld",
            weight2kDimValue, weight1NDimValue);
    return false;
  }
  return true;
}

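// Validates the matmul shape chain: x is (..., k1); weight1 is (E, k1, n1) or
// (k1, n1); weight2 is (E, k2, n2) or (k2, n2), where E is the expert count.
// Without glu, k2 == n1 and n2 == k1 must hold; with glu, GluShapeCheck
// applies instead. k1 and k2 are both capped at 65535.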
aclnnStatus CheckFmapWeightShape(const FFNParams &ffnParams)
{
  auto xDimNum = ffnParams.x->GetViewShape().GetDimNum();
  auto xKdimIndex = xDimNum - 1;  // the last dimension of x is its k dimension
  int64_t xKDimValue = ffnParams.x->GetViewShape().GetDim(xKdimIndex);
  // x's dim num should be within 2 ~ 8
  CHECK_COND(xDimNum >= DIM_LIMIT_LOWER && xDimNum <= DIM_LIMIT_UPPER, ACLNN_ERR_PARAM_INVALID,
             "x dim num should be within 2 ~ 8, but x dim num is %zu", xDimNum);

  bool hasExperts = (ffnParams.expertTokens != nullptr);
  // weight dim num should be 3 with expertTokens and 2 without
  size_t weightSize = hasExperts ? DIM_NUM_THREE : DIM_NUM_TWO;
  size_t weight1DimNum = ffnParams.weight1->GetViewShape().GetDimNum();
  size_t weight2DimNum = ffnParams.weight2->GetViewShape().GetDimNum();
  size_t kDimIdx = weightSize - 2;  // 2 represents the penultimate dimension of weight
  size_t nDimIdx = weightSize - 1;  // 1 represents the last dimension of weight
  int64_t weight1kDimValue = ffnParams.weight1->GetViewShape().GetDim(kDimIdx);
  int64_t weight1NDimValue = ffnParams.weight1->GetViewShape().GetDim(nDimIdx);
  int64_t weight2kDimValue = ffnParams.weight2->GetViewShape().GetDim(kDimIdx);
  int64_t weight2NDimValue = ffnParams.weight2->GetViewShape().GetDim(nDimIdx);
  CHECK_COND(weight1kDimValue <= MAX_K && weight2kDimValue <= MAX_K, ACLNN_ERR_PARAM_INVALID, "kDimValue of weight1 "
             "%ld and weight2 %ld should both be equal to or less than 65535", weight1kDimValue, weight2kDimValue);
  // Validate weight1 and weight2 dim nums
  CHECK_COND(weight1DimNum == weightSize, ACLNN_ERR_PARAM_INVALID,
             "weight1 DimNum should be (3: has experts, 2: no experts), but weight1 DimNum is %zu", weight1DimNum);
  CHECK_COND(weight2DimNum == weightSize, ACLNN_ERR_PARAM_INVALID,
             "weight2 DimNum should be (3: has experts, 2: no experts), but weight2 DimNum is %zu", weight2DimNum);

  CHECK_COND(xKDimValue == weight1kDimValue, ACLNN_ERR_PARAM_INVALID,
             "x's kDimValue is not equal to weight1's kDimValue: x's kDimValue is %ld but weight1's "
             "kDimValue is %ld", xKDimValue, weight1kDimValue);

  bool isGlu = (static_cast<uint32_t>(ffnParams.activationType) >= static_cast<uint32_t>(ActiveType::GEGLU));
  if (isGlu) {
    CHECK_COND(GluShapeCheck(weight1NDimValue, weight2kDimValue), ACLNN_ERR_PARAM_INVALID, "GluShapeCheck failed.");
    return ACLNN_SUCCESS;
  }
  // Validate k1 == n2 and n1 == k2 between weight1 and weight2
  CHECK_COND(weight1kDimValue == weight2NDimValue, ACLNN_ERR_PARAM_INVALID,
             "weight1 KDimValue is not equal to weight2 NDimValue, weight1 KDimValue is %ld, "
             "but weight2 NDimValue is %ld", weight1kDimValue, weight2NDimValue);
  CHECK_COND(weight1NDimValue == weight2kDimValue, ACLNN_ERR_PARAM_INVALID,
             "weight1 NDimValue is not equal to weight2 KDimValue, weight1 NDimValue is %ld, "
             "but weight2 KDimValue is %ld",
             weight1NDimValue, weight2kDimValue);
  if (hasExperts) {
    int64_t expertsNum = ffnParams.expertTokens->GetViewShape().GetDim(0);
    CHECK_COND(ffnParams.weight1->GetViewShape().GetDim(0) == expertsNum, ACLNN_ERR_PARAM_INVALID,
               "weight1 length should be %ld, but weight1 length is %ld", expertsNum,
               ffnParams.weight1->GetViewShape().GetDim(0));
    CHECK_COND(ffnParams.weight2->GetViewShape().GetDim(0) == expertsNum, ACLNN_ERR_PARAM_INVALID,
               "weight2 length should be %ld, but weight2 length is %ld", expertsNum,
               ffnParams.weight2->GetViewShape().GetDim(0));
  }
  return ACLNN_SUCCESS;
}

aclnnStatus IsAntiQuantEmpty(const FFNParams &ffnParams)
{
  CHECK_RET(ffnParams.antiquantScale1 == nullptr, ACLNN_ERR_PARAM_INVALID);
  CHECK_RET(ffnParams.antiquantScale2 == nullptr, ACLNN_ERR_PARAM_INVALID);
  CHECK_RET(ffnParams.antiquantOffset1 == nullptr, ACLNN_ERR_PARAM_INVALID);
  CHECK_RET(ffnParams.antiquantOffset2 == nullptr, ACLNN_ERR_PARAM_INVALID);
  return ACLNN_SUCCESS;
}

aclnnStatus IsQuantEmpty(const FFNParams &ffnParams)
{
  CHECK_RET(ffnParams.scale == nullptr, ACLNN_ERR_PARAM_INVALID);
  CHECK_RET(ffnParams.offset == nullptr, ACLNN_ERR_PARAM_INVALID);
  CHECK_RET(ffnParams.deqScale1 == nullptr, ACLNN_ERR_PARAM_INVALID);
  CHECK_RET(ffnParams.deqScale2 == nullptr, ACLNN_ERR_PARAM_INVALID);
  return ACLNN_SUCCESS;
}

aclnnStatus CheckBiasDtype(const FFNParams &ffnParams, const DtypeCheck &biasSupportList)
{
  if (ffnParams.bias1 != nullptr) {
    OP_CHECK_DTYPE_NOT_SUPPORT(ffnParams.bias1, biasSupportList, return ACLNN_ERR_PARAM_INVALID);
  }
  if (ffnParams.bias2 != nullptr) {
    OP_CHECK_DTYPE_NOT_SUPPORT(ffnParams.bias2, biasSupportList, return ACLNN_ERR_PARAM_INVALID);
  }
  return ACLNN_SUCCESS;
}

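// Quant scale/offset rules: a scale of length 1 means per-tensor quantization,
// while a scale whose last dim equals weight1's n1 means per-channel (smooth)
// quantization, which additionally constrains the deqScale dtype. offset must
// have length 1 without experts, or length E (the expert count) with experts.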
aclnnStatus CheckQuantScaleAndOffset(const FFNParams &ffnParams, bool hasExperts, size_t weightNDimIdx)
{
  // Validate whether scale is a smooth (per-channel) scale, and whether the scale and offset
  // dim nums and shapes are correct with and without experts
  auto scaleDimNum = ffnParams.scale->GetViewShape().GetDimNum();
  int64_t weight1N1 = ffnParams.weight1->GetViewShape().GetDim(weightNDimIdx);
  int64_t scaleLengthDim0 = ffnParams.scale->GetViewShape().GetDim(0);
  DataType deqScaleDtype = ffnParams.deqScale1->GetDataType();
  int64_t requiredLength = 1;

  if (hasExperts) {
    requiredLength = ffnParams.expertTokens->GetViewShape().GetDim(0);
    CHECK_COND(scaleDimNum == DIM_NUM_TWO || scaleDimNum == DIM_NUM_ONE, ACLNN_ERR_PARAM_INVALID,
               "scale dim num should be 1 or 2 in quant scenario, but currently is %zu", scaleDimNum);
    CHECK_COND(scaleLengthDim0 == requiredLength, ACLNN_ERR_PARAM_INVALID,
               "scale length should be %ld, but scale length is %ld", requiredLength, scaleLengthDim0);
    if (scaleDimNum == DIM_NUM_TWO) {
      int64_t scaleLengthDim1 = ffnParams.scale->GetViewShape().GetDim(1);
      CHECK_COND(scaleLengthDim1 == weight1N1, ACLNN_ERR_PARAM_INVALID,
                 "scale length should be %ld, but scale length is %ld", weight1N1, scaleLengthDim1);
      CHECK_COND(deqScaleDtype == DataType::DT_UINT64 || deqScaleDtype == DataType::DT_BF16 ||
                 deqScaleDtype == DataType::DT_INT64, ACLNN_ERR_PARAM_INVALID,
                 "In per-channel mode, deqScaleDtype should be UINT64 or INT64 for output dtype is FLOAT16, "
                 "BFLOAT16 for output dtype is BFLOAT16, others are not supported.");
    }
  } else {
    CHECK_COND(scaleDimNum == 1 && (scaleLengthDim0 == 1 || scaleLengthDim0 == weight1N1), ACLNN_ERR_PARAM_INVALID,
               "scale dim num should be 1 and its dim value should be equal to n1 or 1 when no experts, "
               "but currently dim num is %zu and dim value is %ld", scaleDimNum, scaleLengthDim0);
    if (scaleLengthDim0 == weight1N1 && weight1N1 != 1) {  // per-channel (smooth) quant scale
      CHECK_COND(deqScaleDtype == DataType::DT_UINT64 || deqScaleDtype == DataType::DT_BF16 ||
                 deqScaleDtype == DataType::DT_INT64, ACLNN_ERR_PARAM_INVALID,
                 "In per-channel mode, deqScaleDtype should be UINT64 or INT64 for output dtype is FLOAT16, "
                 "BFLOAT16 for output dtype is BFLOAT16, others are not supported.");
    }
  }
  int64_t offsetLength = ffnParams.offset->GetViewShape().GetDim(0);
  CHECK_COND(offsetLength == requiredLength, ACLNN_ERR_PARAM_INVALID,
             "offset length should be %ld, but offset length is %ld", requiredLength, offsetLength);
  return ACLNN_SUCCESS;
}

aclnnStatus CheckQuantDimNumAndShape(const FFNParams &ffnParams)
{
  // Validate dequant param dim nums with and without experts
  bool hasExperts = (ffnParams.expertTokens != nullptr);
  size_t dequantParamSize = hasExperts ? DIM_NUM_TWO : DIM_NUM_ONE;
  auto dequantScale1DimNum = ffnParams.deqScale1->GetViewShape().GetDimNum();
  auto dequantScale2DimNum = ffnParams.deqScale2->GetViewShape().GetDimNum();
  CHECK_COND(dequantScale1DimNum == dequantParamSize, ACLNN_ERR_PARAM_INVALID,
             "deqScale1 DimNum should be (2: has experts, 1: no experts), but deqScale1 DimNum is %zu",
             dequantScale1DimNum);
  CHECK_COND(dequantScale2DimNum == dequantParamSize, ACLNN_ERR_PARAM_INVALID,
             "deqScale2 DimNum should be (2: has experts, 1: no experts), but deqScale2 DimNum is %zu",
             dequantScale2DimNum);
  auto offsetDimNum = ffnParams.offset->GetViewShape().GetDimNum();
  CHECK_COND(offsetDimNum == 1, ACLNN_ERR_PARAM_INVALID, "offset DimNum should be 1, but offset DimNum is %zu",
             offsetDimNum);
  // Validate that the last dim of dequant params matches the last dim of weight
  size_t weightSize = hasExperts ? DIM_NUM_THREE : DIM_NUM_TWO;
  size_t weightNDimIdx = weightSize - 1;
  size_t dequantParamNDimIdx = dequantParamSize - 1;
  int64_t weight1NDimValue = ffnParams.weight1->GetViewShape().GetDim(weightNDimIdx);
  int64_t weight2NDimValue = ffnParams.weight2->GetViewShape().GetDim(weightNDimIdx);
  int64_t dequantScale1NDimValue = ffnParams.deqScale1->GetViewShape().GetDim(dequantParamNDimIdx);
  int64_t dequantScale2NDimValue = ffnParams.deqScale2->GetViewShape().GetDim(dequantParamNDimIdx);
  CHECK_COND(dequantScale1NDimValue == weight1NDimValue, ACLNN_ERR_PARAM_INVALID,
             "deqScale1 last dim should be %ld, but deqScale1 last dim is %ld", weight1NDimValue,
             dequantScale1NDimValue);
  CHECK_COND(dequantScale2NDimValue == weight2NDimValue, ACLNN_ERR_PARAM_INVALID,
             "deqScale2 last dim should be %ld, but deqScale2 last dim is %ld", weight2NDimValue,
             dequantScale2NDimValue);
  CHECK_COND(CheckQuantScaleAndOffset(ffnParams, hasExperts, weightNDimIdx) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
             "Detected that the quant scale/offset dim num, shape, or dtype is not right!");
  // Validate that the first dim of dequant params matches the number of experts when experts exist
  if (hasExperts) {
    int64_t dequantScale1EDimValue = ffnParams.deqScale1->GetViewShape().GetDim(0);
    int64_t dequantScale2EDimValue = ffnParams.deqScale2->GetViewShape().GetDim(0);
    int64_t expertTokensLength = ffnParams.expertTokens->GetViewShape().GetDim(0);
    CHECK_COND(dequantScale1EDimValue == expertTokensLength, ACLNN_ERR_PARAM_INVALID,
               "deqScale1 first dim should be %ld, but deqScale1 first dim is %ld", expertTokensLength,
               dequantScale1EDimValue);
    CHECK_COND(dequantScale2EDimValue == expertTokensLength, ACLNN_ERR_PARAM_INVALID,
               "deqScale2 first dim should be %ld, but deqScale2 first dim is %ld", expertTokensLength,
               dequantScale2EDimValue);
  }
  return ACLNN_SUCCESS;
}

aclnnStatus CheckQuant(const FFNParams &ffnParams)
{
  // Validate that quant params are not null
  OP_CHECK_NULL(ffnParams.scale, return ACLNN_ERR_PARAM_INVALID);
  OP_CHECK_NULL(ffnParams.offset, return ACLNN_ERR_PARAM_INVALID);
  OP_CHECK_NULL(ffnParams.deqScale1, return ACLNN_ERR_PARAM_INVALID);
  OP_CHECK_NULL(ffnParams.deqScale2, return ACLNN_ERR_PARAM_INVALID);

  DataType yDtype = ffnParams.y->GetDataType();
  const std::initializer_list<op::DataType> dequantDtypeSupportList =
    {op::DataType::DT_UINT64, op::DataType::DT_INT64, op::DataType::DT_FLOAT};
  if (yDtype == DataType::DT_BF16) {
    OP_CHECK_DTYPE_NOT_MATCH(ffnParams.deqScale1, op::DataType::DT_BF16, return ACLNN_ERR_PARAM_INVALID);
  } else {
    OP_CHECK_DTYPE_NOT_SUPPORT(ffnParams.deqScale1, dequantDtypeSupportList, return ACLNN_ERR_PARAM_INVALID);
  }
  OP_CHECK_DTYPE_NOT_SAME(ffnParams.deqScale1, ffnParams.deqScale2, return ACLNN_ERR_PARAM_INVALID);
  OP_CHECK_DTYPE_NOT_SAME(ffnParams.weight1, ffnParams.weight2, return ACLNN_ERR_PARAM_INVALID);
  OP_CHECK_DTYPE_NOT_MATCH(ffnParams.scale, op::DataType::DT_FLOAT, return ACLNN_ERR_PARAM_INVALID);
  OP_CHECK_DTYPE_NOT_MATCH(ffnParams.offset, op::DataType::DT_FLOAT, return ACLNN_ERR_PARAM_INVALID);

  CHECK_COND(IsAntiQuantEmpty(ffnParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
             "Detected quant input, but antiquant inputs are not empty!");
  CHECK_COND(CheckQuantDimNumAndShape(ffnParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
             "Detected that the quant params' dim num, shape, or dtype is not right!");
  return ACLNN_SUCCESS;
}

aclnnStatus CheckHighPrecisionBF16(const FFNParams &ffnParams)
{
  OP_CHECK_DTYPE_NOT_SAME(ffnParams.weight1, ffnParams.weight2, return ACLNN_ERR_PARAM_INVALID);
  CHECK_COND(IsQuantEmpty(ffnParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
             "Detected high precision, but quant inputs are not empty!");
  CHECK_COND(IsAntiQuantEmpty(ffnParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
             "Detected high precision, but antiquant inputs are not empty!");
  CHECK_COND(CheckBiasDtype(ffnParams, {DataType::DT_FLOAT}) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
             "Detected high precision, bias dtype is not right!");
  CHECK_COND(ffnParams.activationType != ActiveType::GEGLU, ACLNN_ERR_PARAM_INVALID,
             "Detected high precision, GEGLU is not supported!");
  return ACLNN_SUCCESS;
}

aclnnStatus CheckHighPrecisionFP16(const FFNParams &ffnParams)
{
  OP_CHECK_DTYPE_NOT_MATCH(ffnParams.weight1, op::DataType::DT_FLOAT16, return ACLNN_ERR_PARAM_INVALID);
  OP_CHECK_DTYPE_NOT_SAME(ffnParams.weight1, ffnParams.weight2, return ACLNN_ERR_PARAM_INVALID);
  CHECK_COND(IsQuantEmpty(ffnParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
             "Detected high precision, but quant inputs are not empty!");
  CHECK_COND(IsAntiQuantEmpty(ffnParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
             "Detected high precision, but antiquant inputs are not empty!");
  OP_CHECK_DTYPE_NOT_SAME(ffnParams.x, ffnParams.weight1, return ACLNN_ERR_PARAM_INVALID);
  CHECK_COND(CheckBiasDtype(ffnParams, {DataType::DT_FLOAT16}) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
             "Detected high precision, bias dtype is not right!");
  CHECK_COND(ffnParams.activationType != ActiveType::GEGLU, ACLNN_ERR_PARAM_INVALID,
             "Detected high precision, GEGLU is not supported!");
  return ACLNN_SUCCESS;
}

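// Antiquant (weight-only quantization) dim rules: per-tensor/per-channel
// params have one dim (two with experts); per-group params carry one extra
// group dim. Scale and offset must agree pairwise in dim num and shape.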
aclnnStatus CheckAntiQuantDimNum(const FFNParams &ffnParams)
{
  // Validate antiquant param dim nums with and without experts
  bool hasExperts = (ffnParams.expertTokens != nullptr);
  // per group num = 1
  size_t antiquantParamSize = hasExperts ? DIM_NUM_TWO : DIM_NUM_ONE;
  // per group num > 1
  size_t groupAntiquantParamSize = hasExperts ? DIM_NUM_THREE : DIM_NUM_TWO;
  auto antiquantScale1DimNum = ffnParams.antiquantScale1->GetViewShape().GetDimNum();
  auto antiquantOffset1DimNum = ffnParams.antiquantOffset1->GetViewShape().GetDimNum();
  auto antiquantScale2DimNum = ffnParams.antiquantScale2->GetViewShape().GetDimNum();
  auto antiquantOffset2DimNum = ffnParams.antiquantOffset2->GetViewShape().GetDimNum();

  CHECK_COND((antiquantScale1DimNum == antiquantParamSize) || (antiquantScale1DimNum == groupAntiquantParamSize),
             ACLNN_ERR_PARAM_INVALID,
             "antiquantScale1 DimNum should be (3 or 2: has experts, 2 or 1: no experts), but DimNum is %zu",
             antiquantScale1DimNum);
  CHECK_COND(antiquantScale1DimNum == antiquantOffset1DimNum, ACLNN_ERR_PARAM_INVALID,
             "antiquantOffset1 DimNum should be equal to antiquantScale1, but antiquantOffset1 DimNum is %zu",
             antiquantOffset1DimNum);
  CHECK_COND(antiquantScale1DimNum == antiquantScale2DimNum, ACLNN_ERR_PARAM_INVALID,
             "antiquantScale2 DimNum should be equal to antiquantScale1, but antiquantScale2 DimNum is %zu",
             antiquantScale2DimNum);
  CHECK_COND(antiquantScale2DimNum == antiquantOffset2DimNum, ACLNN_ERR_PARAM_INVALID,
             "antiquantOffset2 DimNum should be equal to antiquantScale2, but antiquantOffset2 DimNum is %zu",
             antiquantOffset2DimNum);
  for (size_t i = 0; i < antiquantScale1DimNum; i++) {
    int64_t curScaleShape = ffnParams.antiquantScale1->GetViewShape().GetDim(i);
    int64_t curOffsetShape = ffnParams.antiquantOffset1->GetViewShape().GetDim(i);
    CHECK_COND(curScaleShape == curOffsetShape, ACLNN_ERR_PARAM_INVALID,
               "antiquantScale1 shape should be equal to antiquantOffset1 shape, now is %ld vs %ld on dim %zu",
               curScaleShape, curOffsetShape, i);
  }
  for (size_t i = 0; i < antiquantScale2DimNum; i++) {
    int64_t curScaleShape = ffnParams.antiquantScale2->GetViewShape().GetDim(i);
    int64_t curOffsetShape = ffnParams.antiquantOffset2->GetViewShape().GetDim(i);
    CHECK_COND(curScaleShape == curOffsetShape, ACLNN_ERR_PARAM_INVALID,
               "antiquantScale2 shape should be equal to antiquantOffset2 shape, now is %ld vs %ld on dim %zu",
               curScaleShape, curOffsetShape, i);
  }
  return ACLNN_SUCCESS;
}

aclnnStatus CheckAntiQuantW4Shape(const FFNParams &ffnParams, const size_t weightNDimIdx)
{
  DataType w1Dtype = ffnParams.weight1->GetDataType();
  if (w1Dtype == DataType::DT_INT4) {
    int64_t weight1NDimValue = ffnParams.weight1->GetViewShape().GetDim(weightNDimIdx);
    int64_t weight2NDimValue = ffnParams.weight2->GetViewShape().GetDim(weightNDimIdx);
    CHECK_COND(weight1NDimValue % JUDGE_EVEN == 0, ACLNN_ERR_PARAM_INVALID,
               "when w1 is int4, the last dimension of the shape N1 %ld should be an even number!", weight1NDimValue);
    CHECK_COND(weight2NDimValue % JUDGE_EVEN == 0, ACLNN_ERR_PARAM_INVALID,
               "when w2 is int4, the last dimension of the shape N2 %ld should be an even number!", weight2NDimValue);
  }

  return ACLNN_SUCCESS;
}

aclnnStatus CheckAntiQuantShapeE(const FFNParams &ffnParams)
{
  if (ffnParams.expertTokens != nullptr) {
    int64_t antiquantScale1EDimValue = ffnParams.antiquantScale1->GetViewShape().GetDim(0);
    int64_t antiquantOffset1EDimValue = ffnParams.antiquantOffset1->GetViewShape().GetDim(0);
    int64_t antiquantScale2EDimValue = ffnParams.antiquantScale2->GetViewShape().GetDim(0);
    int64_t antiquantOffset2EDimValue = ffnParams.antiquantOffset2->GetViewShape().GetDim(0);
    int64_t expertTokensLength = ffnParams.expertTokens->GetViewShape().GetDim(0);
    CHECK_COND(antiquantScale1EDimValue == expertTokensLength, ACLNN_ERR_PARAM_INVALID,
               "antiquantScale1 first dim should be %ld, but antiquantScale1 first dim is %ld", expertTokensLength,
               antiquantScale1EDimValue);
    CHECK_COND(antiquantScale2EDimValue == expertTokensLength, ACLNN_ERR_PARAM_INVALID,
               "antiquantScale2 first dim should be %ld, but antiquantScale2 first dim is %ld", expertTokensLength,
               antiquantScale2EDimValue);
    CHECK_COND(antiquantOffset1EDimValue == expertTokensLength, ACLNN_ERR_PARAM_INVALID,
               "antiquantOffset1 first dim should be %ld, but antiquantOffset1 first dim is %ld", expertTokensLength,
               antiquantOffset1EDimValue);
    CHECK_COND(antiquantOffset2EDimValue == expertTokensLength, ACLNN_ERR_PARAM_INVALID,
               "antiquantOffset2 first dim should be %ld, but antiquantOffset2 first dim is %ld", expertTokensLength,
               antiquantOffset2EDimValue);
  }
  return ACLNN_SUCCESS;
}

aclnnStatus CheckAntiQuantShape(const FFNParams &ffnParams)
{
  // Validate that the last dim of antiquant params matches the last dim of weight
  size_t antiquantParamSize = ffnParams.antiquantScale1->GetViewShape().GetDimNum();
  size_t weightSize = ffnParams.expertTokens != nullptr ? DIM_NUM_THREE : DIM_NUM_TWO;
  size_t weightNDimIdx = weightSize - 1;
  size_t antiquantParamNDimIdx = antiquantParamSize - 1;
  int64_t weight1NDimValue = ffnParams.weight1->GetViewShape().GetDim(weightNDimIdx);
  int64_t weight2NDimValue = ffnParams.weight2->GetViewShape().GetDim(weightNDimIdx);
  int64_t antiquantScale1NDimValue = ffnParams.antiquantScale1->GetViewShape().GetDim(antiquantParamNDimIdx);
  int64_t antiquantOffset1NDimValue = ffnParams.antiquantOffset1->GetViewShape().GetDim(antiquantParamNDimIdx);
  int64_t antiquantScale2NDimValue = ffnParams.antiquantScale2->GetViewShape().GetDim(antiquantParamNDimIdx);
  int64_t antiquantOffset2NDimValue = ffnParams.antiquantOffset2->GetViewShape().GetDim(antiquantParamNDimIdx);
  CHECK_COND(
    antiquantScale1NDimValue == weight1NDimValue, ACLNN_ERR_PARAM_INVALID,
    "antiquantScale1 last dim should be %ld, but antiquantScale1 last dim is %ld", weight1NDimValue,
    antiquantScale1NDimValue);
  CHECK_COND(
    antiquantOffset1NDimValue == weight1NDimValue, ACLNN_ERR_PARAM_INVALID,
    "antiquantOffset1 last dim should be %ld, but antiquantOffset1 last dim is %ld", weight1NDimValue,
    antiquantOffset1NDimValue);
  CHECK_COND(
    antiquantScale2NDimValue == weight2NDimValue, ACLNN_ERR_PARAM_INVALID,
    "antiquantScale2 last dim should be %ld, but antiquantScale2 last dim is %ld", weight2NDimValue,
    antiquantScale2NDimValue);
  CHECK_COND(
    antiquantOffset2NDimValue == weight2NDimValue, ACLNN_ERR_PARAM_INVALID,
    "antiquantOffset2 last dim should be %ld, but antiquantOffset2 last dim is %ld", weight2NDimValue,
    antiquantOffset2NDimValue);

  CHECK_COND(CheckAntiQuantW4Shape(ffnParams, weightNDimIdx) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
             "Detected antiquant w is int4, shape N is not an even number!");
  CHECK_COND(CheckAntiQuantShapeE(ffnParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
             "Detected antiquant params dim0 and expert_tokens length are not equal");
  return ACLNN_SUCCESS;
}

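// Per-group antiquant is detected by the extra group dim on antiquantScale1
// (3 dims with experts, 2 without). In that case the group count must divide
// the corresponding weight's k dim, and both weights must be INT4.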
aclnnStatus CheckAntiQuantPergroup(const FFNParams &ffnParams)
{
  bool hasExperts = (ffnParams.expertTokens != nullptr);
  size_t antiquantParamSize = ffnParams.antiquantScale1->GetViewShape().GetDimNum();
  if ((hasExperts && antiquantParamSize == DIM_NUM_THREE) || (!hasExperts && antiquantParamSize == DIM_NUM_TWO)) {
    size_t weightSize = ffnParams.weight1->GetViewShape().GetDimNum();
    size_t weightKDimIdx = weightSize - 2;
    int64_t weight1KDimValue = ffnParams.weight1->GetViewShape().GetDim(weightKDimIdx);
    int64_t weight2KDimValue = ffnParams.weight2->GetViewShape().GetDim(weightKDimIdx);
    DataType w1Dtype = ffnParams.weight1->GetDataType();
    DataType w2Dtype = ffnParams.weight2->GetDataType();
    size_t antiquantParamPergroupIdx = antiquantParamSize - 2;
    int64_t antiquantScale1GroupValue = ffnParams.antiquantScale1->GetViewShape().GetDim(antiquantParamPergroupIdx);
    int64_t antiquantScale2GroupValue = ffnParams.antiquantScale2->GetViewShape().GetDim(antiquantParamPergroupIdx);
    CHECK_COND(antiquantScale1GroupValue >= 1, ACLNN_ERR_PARAM_INVALID,
               "antiquantScale1GroupValue %ld must be greater than or equal to 1.", antiquantScale1GroupValue);
    CHECK_COND(antiquantScale2GroupValue >= 1, ACLNN_ERR_PARAM_INVALID,
               "antiquantScale2GroupValue %ld must be greater than or equal to 1.", antiquantScale2GroupValue);
    CHECK_COND((weight1KDimValue % antiquantScale1GroupValue) == 0, ACLNN_ERR_PARAM_INVALID,
               "weight1KDimValue %ld must be divisible by antiquantScale1GroupValue %ld, with no remainder.",
               weight1KDimValue, antiquantScale1GroupValue);
    CHECK_COND((weight2KDimValue % antiquantScale2GroupValue) == 0, ACLNN_ERR_PARAM_INVALID,
               "weight2KDimValue %ld must be divisible by antiquantScale2GroupValue %ld, with no remainder.",
               weight2KDimValue, antiquantScale2GroupValue);
    CHECK_COND((w1Dtype == DataType::DT_INT4) && (w2Dtype == DataType::DT_INT4), ACLNN_ERR_PARAM_INVALID,
               "weight dtype must be INT4.");
  }

  return ACLNN_SUCCESS;
}

aclnnStatus CheckAntiQuant(const FFNParams &ffnParams, const DtypeCheck &supportList, const DtypeCheck &biasSupportList)
{
  // Validate that antiquant params are not null
  OP_CHECK_NULL(ffnParams.antiquantScale1, return ACLNN_ERR_PARAM_INVALID);
  OP_CHECK_NULL(ffnParams.antiquantScale2, return ACLNN_ERR_PARAM_INVALID);
  OP_CHECK_NULL(ffnParams.antiquantOffset1, return ACLNN_ERR_PARAM_INVALID);
  OP_CHECK_NULL(ffnParams.antiquantOffset2, return ACLNN_ERR_PARAM_INVALID);
  OP_CHECK_DTYPE_NOT_SAME(ffnParams.weight1, ffnParams.weight2, return ACLNN_ERR_PARAM_INVALID);
  OP_CHECK_DTYPE_NOT_SUPPORT(ffnParams.antiquantScale1, supportList, return ACLNN_ERR_PARAM_INVALID);
  OP_CHECK_DTYPE_NOT_SUPPORT(ffnParams.antiquantOffset1, supportList, return ACLNN_ERR_PARAM_INVALID);
  OP_CHECK_DTYPE_NOT_SUPPORT(ffnParams.antiquantScale2, supportList, return ACLNN_ERR_PARAM_INVALID);
  OP_CHECK_DTYPE_NOT_SUPPORT(ffnParams.antiquantOffset2, supportList, return ACLNN_ERR_PARAM_INVALID);
  CHECK_COND(IsQuantEmpty(ffnParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
             "Detected antiquant input, but quant inputs are not empty!");
  CHECK_COND(CheckBiasDtype(ffnParams, biasSupportList) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
             "Detected AntiQuant, bias dtype is not right!");
  CHECK_COND(CheckAntiQuantDimNum(ffnParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
             "Detected antiquant params dim num is not right!");
  CHECK_COND(CheckAntiQuantShape(ffnParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
             "Detected antiquant params shape is not right!");
  CHECK_COND(CheckAntiQuantPergroup(ffnParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
             "Detected antiquant params per group num or weight dtype is not right!");
  return ACLNN_SUCCESS;
}

aclnnStatus CheckGlu(const FFNParams &ffnParams)
{
  CHECK_RET(CheckDtypeValidForFFNOp(ffnParams.x, ffnParams.weight1, ffnParams.weight2), ACLNN_ERR_PARAM_INVALID);
  CHECK_COND(ffnParams.innerPrecise == HIGH_PERFORMANCE, ACLNN_ERR_PARAM_INVALID,
             "Detected glu, innerPrecise must be high performance (1)!");
  CHECK_COND(ffnParams.expertTokens == nullptr, ACLNN_ERR_PARAM_INVALID,
             "Detected glu, expert tokens should be nullptr, but currently it is not nullptr!");
  CHECK_COND(CheckBiasDtype(ffnParams, {DataType::DT_FLOAT16}) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
             "Detected glu, bias dtype is not right!");
  return ACLNN_SUCCESS;
}

aclnnStatus CheckHighPerformance(const FFNParams &ffnParams)
{
  OP_CHECK_DTYPE_NOT_MATCH(ffnParams.weight1, op::DataType::DT_FLOAT16, return ACLNN_ERR_PARAM_INVALID);
  OP_CHECK_DTYPE_NOT_SAME(ffnParams.weight1, ffnParams.weight2, return ACLNN_ERR_PARAM_INVALID);
  CHECK_COND(IsQuantEmpty(ffnParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
             "Detected high performance, but quant inputs are not empty!");
  CHECK_COND(IsAntiQuantEmpty(ffnParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
             "Detected high performance, but antiquant inputs are not empty!");
  CHECK_COND(CheckBiasDtype(ffnParams, {DataType::DT_FLOAT16}) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
             "Detected HighPerformance, bias dtype is not right!");
  return ACLNN_SUCCESS;
}

aclnnStatus CheckOptionalParams(FFNParams &ffnParams)
{
  DataType xDtype = ffnParams.x->GetDataType();
  DataType w1Dtype = ffnParams.weight1->GetDataType();

  const DtypeCheck xSupportList = {DataType::DT_FLOAT16, DataType::DT_BF16, DataType::DT_INT8};
  OP_CHECK_DTYPE_NOT_SUPPORT(ffnParams.x, xSupportList, return ACLNN_ERR_PARAM_INVALID);

  OP_CHECK_DTYPE_NOT_SAME(ffnParams.weight1, ffnParams.weight2, return ACLNN_ERR_PARAM_INVALID);
  const DtypeCheck wSupportList = {DataType::DT_FLOAT16, DataType::DT_BF16, DataType::DT_INT8, DataType::DT_INT4};
  OP_CHECK_DTYPE_NOT_SUPPORT(ffnParams.weight1, wSupportList, return ACLNN_ERR_PARAM_INVALID);

  if (static_cast<uint32_t>(ffnParams.activationType) >= static_cast<uint32_t>(ActiveType::GEGLU)) {
    return CheckGlu(ffnParams);  // glu, mixed operator
  }

  if (xDtype == DataType::DT_BF16) {
    OP_CHECK_DTYPE_NOT_MATCH(ffnParams.y, op::DataType::DT_BF16, return ACLNN_ERR_PARAM_INVALID);
    if (w1Dtype == DataType::DT_BF16) {
      CHECK_COND(ffnParams.innerPrecise == HIGH_PRECISION, ACLNN_ERR_PARAM_INVALID,
                 "FFN bfloat16 only supports high precision now!");
      return CheckHighPrecisionBF16(ffnParams);  // HIGH_PRECISION_BF16
    } else if (w1Dtype == DataType::DT_INT8 || w1Dtype == DataType::DT_INT4) {
      return CheckAntiQuant(ffnParams, {DataType::DT_BF16}, {DataType::DT_FLOAT});  // ANTIQUANT_BF16_W8+W4
    } else {
      return ACLNN_ERR_PARAM_INVALID;
    }
  }

  DataType yDtype = ffnParams.y->GetDataType();
  if (xDtype == DataType::DT_INT8) {
    OP_CHECK_DTYPE_NOT_MATCH(ffnParams.weight1, op::DataType::DT_INT8, return ACLNN_ERR_PARAM_INVALID);
    CHECK_COND(CheckBiasDtype(ffnParams, {DataType::DT_INT32}) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID,
               "Detected QUANT, bias dtype is not right!");
    if (yDtype == DataType::DT_BF16 || yDtype == DataType::DT_FLOAT16) {
      return CheckQuant(ffnParams);  // QUANT
    } else {
      return ACLNN_ERR_PARAM_INVALID;
    }
  }

  OP_CHECK_DTYPE_NOT_MATCH(ffnParams.y, op::DataType::DT_FLOAT16, return ACLNN_ERR_PARAM_INVALID);

  // mmDataType should be DataType::DT_FLOAT16 here
  CHECK_COND(xDtype == DataType::DT_FLOAT16, ACLNN_ERR_PARAM_INVALID,
             "x dtype should be float16 here, but found %s.", op::ToString(xDtype).GetString());
  if (w1Dtype == DataType::DT_INT8 || w1Dtype == DataType::DT_INT4) {
    return CheckAntiQuant(ffnParams, {DataType::DT_FLOAT16}, {DataType::DT_FLOAT16}); // ANTIQUANT_FP16_W8+W4
  }
  if (ffnParams.innerPrecise == HIGH_PRECISION) {
    return CheckHighPrecisionFP16(ffnParams);  // HIGH_PRECISION_FP16
  }
  return CheckHighPerformance(ffnParams);  // HIGH_PERFORMANCE, expertTokens!=nullptr
}

static aclnnStatus CheckFormat(const FFNParams &ffnParams) {
  bool is310P = GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND310P;
  op::Format xFormat = ffnParams.x->GetStorageFormat();
  op::Format weight1Format = ffnParams.weight1->GetStorageFormat();
  op::Format weight2Format = ffnParams.weight2->GetStorageFormat();
  op::Format yFormat = ffnParams.y->GetStorageFormat();
  bool isXFormatValid = xFormat < Format::FORMAT_END &&
    (!op::IsPrivateFormat(xFormat) || (is310P && xFormat == Format::FORMAT_FRACTAL_NZ));
  CHECK_COND(isXFormatValid, ACLNN_ERR_PARAM_INVALID, "format of x %s is invalid.", op::ToString(xFormat).GetString());

  bool isWeight1FormatValid = weight1Format < Format::FORMAT_END &&
    (!op::IsPrivateFormat(weight1Format) || (is310P && weight1Format == Format::FORMAT_FRACTAL_NZ));
  CHECK_COND(isWeight1FormatValid, ACLNN_ERR_PARAM_INVALID, "format of weight1 %s is invalid.",
             op::ToString(weight1Format).GetString());

  bool isWeight2FormatValid = weight2Format < Format::FORMAT_END &&
    (!op::IsPrivateFormat(weight2Format) || (is310P && weight2Format == Format::FORMAT_FRACTAL_NZ));
  CHECK_COND(isWeight2FormatValid, ACLNN_ERR_PARAM_INVALID, "format of weight2 %s is invalid.",
             op::ToString(weight2Format).GetString());

  bool isYFormatValid = yFormat < Format::FORMAT_END &&
    (!op::IsPrivateFormat(yFormat) || (is310P && yFormat == Format::FORMAT_FRACTAL_NZ));
  CHECK_COND(isYFormatValid, ACLNN_ERR_PARAM_INVALID, "format of y %s is invalid.", op::ToString(yFormat).GetString());

  return ACLNN_SUCCESS;
}

inline static aclnnStatus CheckParam(FFNParams &ffnParams)
{
  CHECK_COND(CheckNotNull(ffnParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_NULLPTR, "one of required inputs is nullptr.");
  CHECK_COND(CheckFormat(ffnParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID, "invalid format.");
  CHECK_RET(CheckY(ffnParams), ACLNN_ERR_PARAM_INVALID);
  CHECK_RET(CheckFmapWeightShape(ffnParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID);
  CHECK_RET(CheckBias(ffnParams), ACLNN_ERR_PARAM_INVALID);
  if (GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND310P) {
    auto xDimNum = ffnParams.x->GetViewShape().GetDimNum();
    CHECK_COND(xDimNum == DIM_LIMIT_LOWER, ACLNN_ERR_PARAM_INVALID,
               "FFN only supports x dim 2 on ASCEND310P, but x dim is %zu", xDimNum);
    CHECK_COND(ffnParams.innerPrecise == HIGH_PERFORMANCE, ACLNN_ERR_PARAM_INVALID,
               "FFN only supports high performance on ASCEND310P now!");
    CHECK_COND(ffnParams.activationType <= ActiveType::GELU, ACLNN_ERR_PARAM_INVALID,
               "activation types supported by FFN on ASCEND310P are fastgelu, relu, silu and gelu; "
               "please check the activation.");
    OP_CHECK_DTYPE_NOT_MATCH(ffnParams.x, op::DataType::DT_FLOAT16, return ACLNN_ERR_PARAM_INVALID);
    OP_CHECK_DTYPE_NOT_MATCH(ffnParams.y, op::DataType::DT_FLOAT16, return ACLNN_ERR_PARAM_INVALID);
    CHECK_COND(ffnParams.expertTokens == nullptr, ACLNN_ERR_PARAM_INVALID,
               "FFN only supports no expert cases on ASCEND310P.");
    CHECK_RET(CheckHighPerformance(ffnParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID);
  } else {
    CHECK_RET(ffnParams.activationType != ActiveType::INVALID_TYPE, ACLNN_ERR_PARAM_INVALID);
    CHECK_RET(CheckOptionalParams(ffnParams) == ACLNN_SUCCESS, ACLNN_ERR_PARAM_INVALID);
  }

  return ACLNN_SUCCESS;
}

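// Makes an input contiguous and, on ASCEND310P only, additionally converts it
// to FRACTAL_NZ. Tensors already stored as FRACTAL_NZ pass through unchanged;
// nullptr inputs are skipped.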
aclnnStatus InputsContiguousAndTransFormat(const aclTensor *tensor, const aclTensor *&reformatedTensor,
                                           const std::string &tensorName, aclOpExecutor *executor) {
  if (tensor == nullptr) {
    return ACLNN_SUCCESS;
  }
  op::Format tensorFormat = tensor->GetStorageFormat();
  if (tensorFormat != Format::FORMAT_FRACTAL_NZ) {
    reformatedTensor = l0op::Contiguous(tensor, executor);
    CHECK_COND(reformatedTensor != nullptr, ACLNN_ERR_INNER_NULLPTR, "%s Contiguous failed.", tensorName.c_str());

    bool is310P = GetCurrentPlatformInfo().GetSocVersion() == SocVersion::ASCEND310P;
    // Non-310P SoCs (e.g. 910B) do not need the NZ conversion; return directly
    if (!is310P) {
      return ACLNN_SUCCESS;
    }

    reformatedTensor = l0op::TransData(reformatedTensor, Format::FORMAT_FRACTAL_NZ, 1, executor);
    CHECK_COND(reformatedTensor != nullptr, ACLNN_ERR_INNER_NULLPTR, "%s TransData failed.", tensorName.c_str());
    return ACLNN_SUCCESS;
  }
  reformatedTensor = tensor;
  return ACLNN_SUCCESS;
}

aclnnStatus OutputTransFormat(const FFNParams &ffnParams, const aclTensor *tensor,
                              const aclTensor *&reformatedTensor, aclOpExecutor* executor) {
  // The remaining case: tensor is the FFN result and must be converted back to y's format
  op::Format requiredFormat = ffnParams.y->GetStorageFormat();
  reformatedTensor = l0op::TransData(tensor, requiredFormat, 1, executor);
  CHECK_COND(reformatedTensor != nullptr, ACLNN_ERR_INNER_NULLPTR, "y TransData failed.");
  return ACLNN_SUCCESS;
}

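// Returns the token count of x: the product of all dims except the last (k)
// dim.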
int64_t GetBatchSizeX(const aclTensor *x)
{
  int64_t bs = 1;
  size_t xDimNum = x->GetViewShape().GetDimNum();
  // i + 1 < xDimNum avoids size_t underflow when xDimNum == 0
  for (size_t i = 0; i + 1 < xDimNum; i++) {
    bs *= x->GetViewShape().GetDim(i);
  }
  return bs;
}

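// Validates expertTokensArr against the token count of x (the product of all
// x dims except the last). Two encodings are supported:
//   tokensIndexFlag == true:  cumulative offsets, e.g. {2, 5, 6} means tokens
//                             split 2/3/1 across three experts; values must be
//                             non-negative and non-decreasing.
//   tokensIndexFlag == false: per-expert counts, e.g. {2, 3, 1}; each >= 0.
// In both cases the total must equal the token count of x.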
aclnnStatus CheckExpertTokens(FFNParams &ffnParams)
{
  size_t numExperts = ffnParams.expertTokensArr->Size();
  int64_t bs = GetBatchSizeX(ffnParams.x);
  int64_t totalNumExpertTokens;

  if (ffnParams.tokensIndexFlag) {
    int64_t prevTokensOffset = 0;
    for (size_t i = 0; i < numExperts; i++) {
      int64_t tokensOffset = (*ffnParams.expertTokensArr)[i];
      CHECK_COND(tokensOffset >= 0, ACLNN_ERR_PARAM_INVALID,
        "Expert tokens index should not be smaller than zero, but %ld is found.", tokensOffset);
      int64_t tokens = tokensOffset - prevTokensOffset;
      CHECK_COND(tokens >= 0, ACLNN_ERR_PARAM_INVALID,
        "Expert tokens indices should not be decremental, but two consecutive numbers %ld and %ld are found.",
        prevTokensOffset, tokensOffset);
      prevTokensOffset = tokensOffset;
    }
    totalNumExpertTokens = (*ffnParams.expertTokensArr)[numExperts - 1];
  } else {
    totalNumExpertTokens = 0;
    for (size_t i = 0; i < numExperts; i++) {
      int64_t tokens = (*ffnParams.expertTokensArr)[i];
      CHECK_COND(tokens >= 0, ACLNN_ERR_PARAM_INVALID,
        "Expert tokens should not be smaller than zero, but %ld is found.", tokens);
      totalNumExpertTokens += tokens;
    }
  }
  CHECK_COND(totalNumExpertTokens == bs, ACLNN_ERR_PARAM_INVALID,
    "Total number of expert tokens should be equal to the product of all x dimensions excluding the last one. "
    "But they are %ld and %ld respectively.", totalNumExpertTokens, bs);
  return ACLNN_SUCCESS;
}

aclnnStatus ConvertExpertTokensAndCheckParams(FFNParams &ffnParams, aclOpExecutor *executor) {
  if (ffnParams.expertTokens != nullptr && ffnParams.expertTokens->GetViewShape().GetDim(0) == 0) {
    ffnParams.expertTokens = nullptr;
  }

  if (ffnParams.expertTokensArr != nullptr && ffnParams.expertTokensArr->Size() != 0) {
    aclnnStatus ret = CheckExpertTokens(ffnParams);
    CHECK_RET(ret == ACLNN_SUCCESS, ret);
    ffnParams.expertTokens = executor->ConvertToTensor(ffnParams.expertTokensArr, op::ToOpDataType(ACL_INT64));
  }

  aclnnStatus ret = CheckParam(ffnParams);
  CHECK_RET(ret == ACLNN_SUCCESS, ret);

  return ret;
}

aclnnStatus GetFFNResultByL0Api(FFNParams &ffnParams, const char *activation,
                                uint64_t *workspaceSize, aclOpExecutor **executor) {
  // Boilerplate: create the OpExecutor
  auto uniqueExecutor = CREATE_EXECUTOR();
  CHECK_RET(uniqueExecutor.get() != nullptr, ACLNN_ERR_INNER_CREATE_EXECUTOR);

  // Boilerplate: parameter checks
  aclnnStatus ret = ConvertExpertTokensAndCheckParams(ffnParams, uniqueExecutor.get());
  CHECK_RET(ret == ACLNN_SUCCESS, ret);

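  // Empty input tensor: nothing to compute, so no workspace is needed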
  if (ffnParams.x->IsEmpty()) {
    *workspaceSize = 0;
    uniqueExecutor.ReleaseTo(executor);
    return ACLNN_SUCCESS;
  }

  const aclTensor* reformatedX = nullptr;
  const aclTensor* reformatedWeight1 = nullptr;
  const aclTensor* reformatedWeight2 = nullptr;
  ret = InputsContiguousAndTransFormat(ffnParams.x, reformatedX, "x", uniqueExecutor.get());
  CHECK_RET(ret == ACLNN_SUCCESS, ret);
  ret = InputsContiguousAndTransFormat(ffnParams.weight1, reformatedWeight1, "weight1", uniqueExecutor.get());
  CHECK_RET(ret == ACLNN_SUCCESS, ret);
  ret = InputsContiguousAndTransFormat(ffnParams.weight2, reformatedWeight2, "weight2", uniqueExecutor.get());
  CHECK_RET(ret == ACLNN_SUCCESS, ret);

  // When the input tensor comes from the PyTorch framework, its view_shape/ori_shape/storage_shape may differ:
  // view_shape is N-dimensional as expected, but ori_shape/storage_shape are flattened.
  // storage_shape is used in the tiling process to identify scenarios such as per-group antiquant,
  // smooth quant, etc., so we call l0op::Contiguous to keep them identical.
  // Currently scale, antiquantScale1 and antiquantScale2 need to be processed.
  const aclTensor* contiQuantScale = nullptr;
  CHECK_COND(
    InputsContiguousAndTransFormat(
      ffnParams.scale, contiQuantScale, "scale", uniqueExecutor.get()) == ACLNN_SUCCESS,
    ACLNN_ERR_PARAM_INVALID, "Convert scale to contiguous tensor failed.");
  const aclTensor* contiAntiScale1 = nullptr;
  CHECK_COND(
    InputsContiguousAndTransFormat(
      ffnParams.antiquantScale1, contiAntiScale1, "antiquantScale1", uniqueExecutor.get()) == ACLNN_SUCCESS,
    ACLNN_ERR_PARAM_INVALID, "Convert antiquantScale1 to contiguous tensor failed.");
  const aclTensor* contiAntiScale2 = nullptr;
  CHECK_COND(
    InputsContiguousAndTransFormat(
      ffnParams.antiquantScale2, contiAntiScale2, "antiquantScale2", uniqueExecutor.get()) == ACLNN_SUCCESS,
    ACLNN_ERR_PARAM_INVALID, "Convert antiquantScale2 to contiguous tensor failed.");

  // Invoke the L0 operator FFN to do the computation
  DataType yDtype = ffnParams.y->GetDataType();
  auto ffnResult = l0op::FFN(reformatedX, reformatedWeight1, reformatedWeight2, ffnParams.expertTokens,
                             ffnParams.bias1, ffnParams.bias2, contiQuantScale, ffnParams.offset, ffnParams.deqScale1,
                             ffnParams.deqScale2, contiAntiScale1, contiAntiScale2,
                             ffnParams.antiquantOffset1, ffnParams.antiquantOffset2, activation,
                             ffnParams.innerPrecise, yDtype, ffnParams.tokensIndexFlag, uniqueExecutor.get());
  CHECK_RET(ffnResult != nullptr, ACLNN_ERR_INNER_NULLPTR);
  const aclTensor* reformatedFFNResult = nullptr;
  ret = OutputTransFormat(ffnParams, ffnResult, reformatedFFNResult, uniqueExecutor.get());
  CHECK_RET(ret == ACLNN_SUCCESS, ret);

  // If the output y is a non-contiguous tensor, copy the contiguous result back into it
  auto viewCopyResult = l0op::ViewCopy(reformatedFFNResult, ffnParams.y, uniqueExecutor.get());
  CHECK_RET(viewCopyResult != nullptr, ACLNN_ERR_INNER_NULLPTR);

  // Boilerplate: get the workspace size needed during computation
  *workspaceSize = uniqueExecutor->GetWorkspaceSize();
  uniqueExecutor.ReleaseTo(executor);
  return ACLNN_SUCCESS;
}

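// Reinterprets INT64 deqScale tensors as UINT64 in place when the output
// dtype is float16; the two deqScales must share the same dtype. Other cases
// (e.g. bfloat16 output) are left unchanged.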
aclnnStatus TransDataTypeDeqscale(FFNParams &ffnParams)
{
  auto deqScale1 = const_cast<aclTensor *>(ffnParams.deqScale1);
  auto deqScale2 = const_cast<aclTensor *>(ffnParams.deqScale2);
  DataType yDtype = ffnParams.y->GetDataType();
  if (deqScale1 != nullptr && deqScale2 != nullptr) {
    OP_CHECK_DTYPE_NOT_SAME(deqScale1, deqScale2, return ACLNN_ERR_PARAM_INVALID);
  }
  if (deqScale1 != nullptr && yDtype == DataType::DT_FLOAT16) {
    DataType deqScale1Dtype = ffnParams.deqScale1->GetDataType();
    if (deqScale1Dtype == DataType::DT_INT64) {
      deqScale1->SetDataType(op::DataType::DT_UINT64);
    }
  }
  if (deqScale2 != nullptr && yDtype == DataType::DT_FLOAT16) {
    DataType deqScale2Dtype = ffnParams.deqScale2->GetDataType();
    if (deqScale2Dtype == DataType::DT_INT64) {
      deqScale2->SetDataType(op::DataType::DT_UINT64);
    }
  }
  ffnParams.deqScale1 = deqScale1;
  ffnParams.deqScale2 = deqScale2;
  return ACLNN_SUCCESS;
}

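/*
 * A minimal usage sketch of the two-phase aclnn calling pattern (illustrative
 * only, not part of this file's logic): tensor and stream creation are elided,
 * and xTensor / w1Tensor / w2Tensor / yTensor / stream are hypothetical
 * handles. Shown for the float16 high-performance path, where every optional
 * input is nullptr:
 *
 *   uint64_t wsSize = 0;
 *   aclOpExecutor *exec = nullptr;
 *   aclnnStatus st = aclnnFFNGetWorkspaceSize(
 *       xTensor, w1Tensor, w2Tensor,
 *       nullptr, nullptr, nullptr,           // expertTokens, bias1, bias2
 *       nullptr, nullptr, nullptr, nullptr,  // scale, offset, deqScale1, deqScale2
 *       nullptr, nullptr, nullptr, nullptr,  // antiquant scales and offsets
 *       "fastgelu", 1, yTensor, &wsSize, &exec);  // innerPrecise = 1 (high performance)
 *   void *ws = nullptr;
 *   if (st == ACLNN_SUCCESS && wsSize > 0) {
 *     (void)aclrtMalloc(&ws, wsSize, ACL_MEM_MALLOC_HUGE_FIRST);
 *   }
 *   if (st == ACLNN_SUCCESS) {
 *     st = aclnnFFN(ws, wsSize, exec, stream);
 *   }
 */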
aclnnStatus aclnnFFNGetWorkspaceSize(const aclTensor *x, const aclTensor *weight1, const aclTensor *weight2,
                                     const aclIntArray *expertTokens, const aclTensor *bias1, const aclTensor *bias2,
                                     const aclTensor *scale, const aclTensor *offset, const aclTensor *deqScale1,
                                     const aclTensor *deqScale2, const aclTensor *antiquantScale1,
                                     const aclTensor *antiquantScale2, const aclTensor *antiquantOffset1,
                                     const aclTensor *antiquantOffset2, const char *activation, int64_t innerPrecise,
                                     const aclTensor *y, uint64_t *workspaceSize, aclOpExecutor **executor)
{
  CHECK_COND(GetCurrentPlatformInfo().GetSocVersion() != SocVersion::ASCEND310P, ACLNN_ERR_PARAM_INVALID,
             "aclnnFFNGetWorkspaceSize and aclnnFFN are not supported on the Ascend310P SoC. "
             "Please use aclnnFFNV2GetWorkspaceSize and aclnnFFNV2 instead!");
  ActiveType activationType = GetActiveType(activation);
  CHECK_COND(activationType != ActiveType::INVALID_TYPE, ACLNN_ERR_PARAM_INVALID,
             "the activation type %s is not supported by the FFN operator; please select a valid activation "
             "according to the documentation",
             activation);
  FFNParams ffnParams{x, weight1, weight2, expertTokens, nullptr, bias1, bias2, scale, offset, deqScale1, deqScale2,
            antiquantScale1, antiquantScale2, antiquantOffset1, antiquantOffset2,
            activationType, innerPrecise, false, y};
  aclnnStatus ret = TransDataTypeDeqscale(ffnParams);
  CHECK_RET(ret == ACLNN_SUCCESS, ret);
  L2_DFX_PHASE_1(aclnnFFN,
                 DFX_IN(ffnParams.x, ffnParams.weight1, ffnParams.weight2, expertTokens, ffnParams.bias1,
                        ffnParams.bias2, ffnParams.scale, ffnParams.offset, ffnParams.deqScale1, ffnParams.deqScale2,
                        ffnParams.antiquantScale1, ffnParams.antiquantScale2, ffnParams.antiquantOffset1,
                        ffnParams.antiquantOffset2, activation, ffnParams.innerPrecise, -1, false),
                 DFX_OUT(ffnParams.y));
  return GetFFNResultByL0Api(ffnParams, activation, workspaceSize, executor);
}

aclnnStatus aclnnFFN(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)
{
  L2_DFX_PHASE_2(aclnnFFN);
  auto ret = CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
  if (ret != ACLNN_SUCCESS) {
    OP_LOGE(ACLNN_ERR_INNER, "Failed to launch FFN on AI Core.");
    return ACLNN_ERR_INNER;
  }
  return ACLNN_SUCCESS;
}

aclnnStatus aclnnFFNV2GetWorkspaceSize(const aclTensor *x, const aclTensor *weight1, const aclTensor *weight2,
                                     const aclIntArray *expertTokens, const aclTensor *bias1, const aclTensor *bias2,
                                     const aclTensor *scale, const aclTensor *offset, const aclTensor *deqScale1,
                                     const aclTensor *deqScale2, const aclTensor *antiquantScale1,
                                     const aclTensor *antiquantScale2, const aclTensor *antiquantOffset1,
                                     const aclTensor *antiquantOffset2, const char *activation, int64_t innerPrecise,
                                     bool tokensIndexFlag, const aclTensor *y, uint64_t *workspaceSize,
                                     aclOpExecutor **executor)
{
  ActiveType activationType = GetActiveType(activation);
  CHECK_COND(activationType != ActiveType::INVALID_TYPE, ACLNN_ERR_PARAM_INVALID,
             "the activation type %s is not supported by the FFN operator; please select a valid activation "
             "according to the documentation",
             activation);
  FFNParams ffnParams{x, weight1, weight2, expertTokens, nullptr, bias1, bias2, scale, offset, deqScale1, deqScale2,
            antiquantScale1, antiquantScale2, antiquantOffset1, antiquantOffset2,
            activationType, innerPrecise, tokensIndexFlag, y};
  aclnnStatus ret = TransDataTypeDeqscale(ffnParams);
  CHECK_RET(ret == ACLNN_SUCCESS, ret);
  L2_DFX_PHASE_1(aclnnFFNV2,
                 DFX_IN(ffnParams.x, ffnParams.weight1, ffnParams.weight2, expertTokens, ffnParams.bias1,
                        ffnParams.bias2, ffnParams.scale, ffnParams.offset, ffnParams.deqScale1, ffnParams.deqScale2,
                        ffnParams.antiquantScale1, ffnParams.antiquantScale2, ffnParams.antiquantOffset1,
                        ffnParams.antiquantOffset2, activation, ffnParams.innerPrecise, -1, tokensIndexFlag),
                 DFX_OUT(ffnParams.y));
  return GetFFNResultByL0Api(ffnParams, activation, workspaceSize, executor);
}

aclnnStatus aclnnFFNV2(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)
{
  L2_DFX_PHASE_2(aclnnFFNV2);
  auto ret = CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
  if (ret != ACLNN_SUCCESS) {
    OP_LOGE(ACLNN_ERR_INNER, "Failed to launch FFN on AI Core.");
    return ACLNN_ERR_INNER;
  }
  return ACLNN_SUCCESS;
}

aclnnStatus aclnnFFNV3GetWorkspaceSize(const aclTensor *x, const aclTensor *weight1, const aclTensor *weight2,
                                     const aclTensor *expertTokensOptional, const aclTensor *bias1Optional,
                                     const aclTensor *bias2Optional, const aclTensor *scaleOptional,
                                     const aclTensor *offsetOptional, const aclTensor *deqScale1Optional,
                                     const aclTensor *deqScale2Optional, const aclTensor *antiquantScale1Optional,
                                     const aclTensor *antiquantScale2Optional, const aclTensor *antiquantOffset1Optional,
                                     const aclTensor *antiquantOffset2Optional, const char *activation,
                                     int64_t innerPrecise, bool tokensIndexFlag, const aclTensor *y,
                                     uint64_t *workspaceSize, aclOpExecutor **executor)
{
  ActiveType activationType = GetActiveType(activation);
  CHECK_COND(activationType != ActiveType::INVALID_TYPE, ACLNN_ERR_PARAM_INVALID,
             "the activation type %s is not supported by the FFN operator; please select a valid activation "
             "according to the documentation",
             activation);
  FFNParams ffnParams{x, weight1, weight2, nullptr, expertTokensOptional, bias1Optional, bias2Optional, scaleOptional,
                      offsetOptional, deqScale1Optional, deqScale2Optional, antiquantScale1Optional,
                      antiquantScale2Optional, antiquantOffset1Optional, antiquantOffset2Optional,
                      activationType, innerPrecise, tokensIndexFlag, y};
  aclnnStatus ret = TransDataTypeDeqscale(ffnParams);
  CHECK_RET(ret == ACLNN_SUCCESS, ret);
  L2_DFX_PHASE_1(aclnnFFNV3,
                 DFX_IN(ffnParams.x, ffnParams.weight1, ffnParams.weight2, expertTokensOptional, ffnParams.bias1,
                        ffnParams.bias2, ffnParams.scale, ffnParams.offset, ffnParams.deqScale1, ffnParams.deqScale2,
                        ffnParams.antiquantScale1, ffnParams.antiquantScale2, ffnParams.antiquantOffset1,
                        ffnParams.antiquantOffset2, activation, ffnParams.innerPrecise, -1, tokensIndexFlag),
                 DFX_OUT(ffnParams.y));
  return GetFFNResultByL0Api(ffnParams, activation, workspaceSize, executor);
}

aclnnStatus aclnnFFNV3(void *workspace, uint64_t workspaceSize, aclOpExecutor *executor, aclrtStream stream)
{
  L2_DFX_PHASE_2(aclnnFFNV3);
  auto ret = CommonOpExecutorRun(workspace, workspaceSize, executor, stream);
  if (ret != ACLNN_SUCCESS) {
    OP_LOGE(ACLNN_ERR_INNER, "Failed to launch FFN on AI Core.");
    return ACLNN_ERR_INNER;
  }
  return ACLNN_SUCCESS;
}

#ifdef __cplusplus
}
#endif